/**
 * Copyright 2023-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FRAMEWORK_GRAPH_CORE_OPTIMIZE_FUSION_PATTERN_DEFINE_H
#define FRAMEWORK_GRAPH_CORE_OPTIMIZE_FUSION_PATTERN_DEFINE_H

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "framework/graph/core/cgraph/graph_fwd.h"
#include "framework/graph/core/node/node.h"

#include "framework/graph/utils/replacer/graph_replacer.h"

#include "framework/common/hcs_types.h"

namespace hiai {
using Id = int;
using Type = std::vector<std::string>;
using NodeInputs = std::vector<Id>;
using NodeChecker = Status (*)(ge::Node&);

// One node of a fusion pattern.
//
// id_       : identifier of this node within the pattern.
// types_    : operator type names this node may match.
// inputs_   : ids of the pattern nodes feeding this node.
// checker_  : optional extra predicate on the candidate graph node (nullptr = none).
// repeatable_: flag whose matching semantics are not finalized yet
//              (NOTE(review): original comment said "to be decided").
struct PatternNode {
    Id id_ {0};
    Type types_;
    NodeInputs inputs_;
    NodeChecker checker_ {nullptr};
    bool repeatable_ {false};

    // Kept non-explicit so brace-lists like {1, {"Conv2D"}, {0}} keep constructing PatternNode.
    // The container parameters are sink arguments: move them into place instead of copying again.
    PatternNode(
        Id id = 0, Type types = {}, NodeInputs inputs = {}, NodeChecker checker = nullptr, bool repeatable = false)
        : id_(id), types_(std::move(types)), inputs_(std::move(inputs)), checker_(checker), repeatable_(repeatable)
    {
    }
};

class HCS_API_EXPORT PatternMapping {
public:
    virtual ge::Node* Node(const Id& id) = 0;
    virtual std::unique_ptr<ge::GraphSrcBoundary> ToGraphSrcBoundary() const = 0;

    virtual ~PatternMapping() = default;
};

// Ids of the pattern nodes that act as the pattern's inputs.
using PatternInputs = std::vector<Id>;
// Ids of the pattern nodes that act as the pattern's output(s).
using PatternOutput = std::vector<Id>;
// Every node that makes up a pattern.
using PatternNodes = std::vector<PatternNode>;

class HCS_API_EXPORT PatternDefine {
public:
    explicit PatternDefine(std::vector<PatternNode> nodes, PatternInputs inputs, PatternOutput output);
    ~PatternDefine() = default;

    std::vector<std::string> AttentionTypes() const;

    std::unique_ptr<PatternMapping> Match(const ge::Node& outputNode, const ge::ComputeGraph& graph);

private:
    bool VerifyPattern() const;
    bool VerifyInputsHomology(PatternMapping* patternMapping);
    bool VerifyOutNode(const ge::Node& node);

    std::unique_ptr<PatternMapping> BfsFromOutput(const ge::ComputeGraph& graph, const ge::Node& outputNode);

private:
    std::vector<PatternNode> patternNodes_;
    std::map<Id, PatternNode*> patternNodeMap_;
    PatternInputs patternInputs_;
    PatternOutput patternOutput_;
};
} // namespace hiai

#endif // FRAMEWORK_GRAPH_CORE_OPTIMIZE_FUSION_PATTERN_DEFINE_H